{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Creating our first agent with Stable Baselines\n",
    "\n",
    "Note that Stable Baselines currently works only with TensorFlow 1.x, so make sure you\n",
    "run the following experiments with TensorFlow 1.x installed.\n",
    "\n",
    "Now, let's create our first deep reinforcement learning algorithm using Stable Baselines.\n",
    "We will build a simple agent using a deep Q network (DQN) for the mountain car climbing\n",
    "task. Recall that in the mountain car climbing task, a car is placed between two\n",
    "mountains, and the goal of the agent is to drive up the mountain on the right.\n",
    "\n",
    "First, let's import Gym and the DQN agent from Stable Baselines:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "import gym\n",
    "from stable_baselines import DQN"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create a mountain car environment:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "env = gym.make('MountainCar-v0')"
   ]
  },
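  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before building the agent, we can inspect the environment. In the mountain car\n",
    "environment, the state is a two-dimensional vector (the car's position and velocity),\n",
    "and there are three discrete actions (push left, do nothing, push right):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#print the observation (state) space and the action space\n",
    "print(env.observation_space)\n",
    "print(env.action_space)"
   ]
  },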
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's instantiate our agent. As shown in the following code, we pass `MlpPolicy`,\n",
    "which means that our network is a multilayer perceptron."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "agent = DQN('MlpPolicy', env, learning_rate=1e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's train the agent by specifying the number of time steps to train for:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<stable_baselines.deepq.dqn.DQN at 0x7f4190078240>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "agent.learn(total_timesteps=25000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "That's it. Building and training a DQN agent is that simple.\n",
    "\n",
    "## Evaluating the trained agent\n",
    "\n",
    "We can also evaluate the trained agent by looking at the mean reward, using\n",
    "`evaluate_policy`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from stable_baselines.common.evaluation import evaluate_policy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the following code, `agent` is the trained agent, `agent.get_env()` returns the\n",
    "environment the agent was trained on, and `n_eval_episodes` is the number of episodes\n",
    "over which to evaluate the agent:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "mean_reward, std_reward = evaluate_policy(agent, agent.get_env(), n_eval_episodes=10)"
   ]
  },
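  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can then print the mean reward obtained over the evaluation episodes:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#print the mean reward over the 10 evaluation episodes\n",
    "print(mean_reward)"
   ]
  },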
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Storing and loading the trained agent\n",
    "\n",
    "With Stable Baselines, we can also save the trained agent and load it later.\n",
    "We can save the agent as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "agent.save(\"DQN_mountain_car_agent\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After saving, we can load the agent as:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "agent = DQN.load(\"DQN_mountain_car_agent\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Viewing the trained agent\n",
    "\n",
    "After training, we can also watch how the trained agent behaves in the environment.\n",
    "\n",
    "First, reset the environment to get the initial state:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "state = env.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#run the agent for 5000 steps\n",
    "for t in range(5000):\n",
    "    \n",
    "    #predict the action to perform in the given state using our trained agent\n",
    "    action, _ = agent.predict(state)\n",
    "    \n",
    "    #perform the predicted action\n",
    "    next_state, reward, done, info = env.step(action)\n",
    "    \n",
    "    #update next state to current state\n",
    "    state = next_state\n",
    "    \n",
    "    #render the environment\n",
    "    env.render()\n",
    "    \n",
    "    #reset the environment when the episode ends\n",
    "    if done:\n",
    "        state = env.reset()\n",
    "\n",
    "#close the rendering window\n",
    "env.close()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
